{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and process-level configuration for the MNIST model experiments below.\n",
    "# NOTE(review): this cell mixes the standalone `keras` package, the public\n",
    "# `tensorflow.keras` API and the private `tensorflow.python.keras` namespace;\n",
    "# `load_model` and `Input` are each imported twice (the later import wins).\n",
    "# Confirm which namespace is intended before relying on any of these names.\n",
    "import numpy as np\n",
    "from keras.datasets import mnist\n",
    "from tensorflow.python.keras.backend import set_session\n",
    "from tensorflow.python.keras.models import load_model\n",
    "from tensorflow.keras.models import Model, load_model, Sequential\n",
    "from tensorflow.keras.layers import Input,Conv2D, Dense, MaxPooling2D, Softmax, Activation, BatchNormalization, Flatten, Dropout, DepthwiseConv2D\n",
    "from tensorflow.keras.layers import MaxPool2D, AvgPool2D, AveragePooling2D, GlobalAveragePooling2D,ZeroPadding2D,Input,Embedding,PReLU,Reshape\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.callbacks import TensorBoard\n",
    "from keras.utils import np_utils\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "import keras.backend as K\n",
    "import tensorflow as tf\n",
    "import time\n",
    "\n",
    "# Expose only CUDA device 0 to this process.\n",
    "# NOTE(review): this is set after TensorFlow has been imported; confirm it still takes\n",
    "# effect in the target environment, or move it above the TensorFlow/Keras imports.\n",
    "from os import environ\n",
    "environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST train and test splits through keras.datasets.mnist.\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append a trailing channel axis and scale the pixels by 1/255.\n",
    "# The ndim guards make this cell safe to re-run: without them a second execution\n",
    "# would silently divide the already-scaled images by 255 again.\n",
    "if x_train.ndim == 3:\n",
    "    x_train = x_train.reshape(x_train.shape + (1,)) / 255\n",
    "if x_test.ndim == 3:\n",
    "    x_test = x_test.reshape(x_test.shape + (1,)) / 255\n",
    "\n",
    "# One-hot encode the 10 digit classes. Labels that are already 2-D are left alone,\n",
    "# because encoding them a second time would produce rank-3 label arrays.\n",
    "if y_train.ndim == 1:\n",
    "    y_train = np_utils.to_categorical(y_train, num_classes=10)\n",
    "if y_test.ndim == 1:\n",
    "    y_test = np_utils.to_categorical(y_test, num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_valid\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant that uses 'valid' padding.\n",
    "\n",
    "    Three stride-2 3x3 convolutions with dim0, 2*dim0 and 4*dim0 filters, each followed\n",
    "    by batch normalisation and ReLU, feed a global-average-pooling + 10-way softmax head.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        # stage 0\n",
    "        Conv2D(dim0, (3, 3), padding='valid', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0'),\n",
    "        BatchNormalization(name='bn0'),\n",
    "        Activation('relu', name='relu0'),\n",
    "        # stage 1\n",
    "        Conv2D(dim0 * 2, (3, 3), padding='valid', strides=(2, 2), name='ftr1'),\n",
    "        BatchNormalization(name='bn1'),\n",
    "        Activation('relu', name='relu1'),\n",
    "        # stage 2 (BN and ReLU get no explicit name here, so Keras assigns one)\n",
    "        Conv2D(dim0 * 4, (3, 3), padding='valid', strides=(2, 2), name='ftr2'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu'),\n",
    "        # classifier head\n",
    "        GlobalAveragePooling2D(name='GAP'),\n",
    "        Dense(10, name='fc1'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 4\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_same\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant that uses 'same' padding.\n",
    "\n",
    "    Three stride-2 3x3 convolutions with dim0, 2*dim0 and 4*dim0 filters, each followed\n",
    "    by batch normalisation and ReLU, feed a global-average-pooling + 10-way softmax head.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        # stage 0\n",
    "        Conv2D(dim0, (3, 3), padding='same', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0'),\n",
    "        BatchNormalization(name='bn0'),\n",
    "        Activation('relu', name='relu0'),\n",
    "        # stage 1\n",
    "        Conv2D(dim0 * 2, (3, 3), padding='same', strides=(2, 2), name='ftr1'),\n",
    "        BatchNormalization(name='bn1'),\n",
    "        Activation('relu', name='relu1'),\n",
    "        # stage 2 (BN and ReLU get no explicit name here, so Keras assigns one)\n",
    "        Conv2D(dim0 * 4, (3, 3), padding='same', strides=(2, 2), name='ftr2'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu'),\n",
    "        # classifier head\n",
    "        GlobalAveragePooling2D(name='GAP'),\n",
    "        Dense(10, name='fc1'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 4\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_fc\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the fully-connected MNIST baseline: flatten, Dense(10), softmax.\n",
    "\n",
    "    Args:\n",
    "        dim0: not used by this variant; accepted so every init_model in this notebook\n",
    "            can be called the same way.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        Flatten(input_shape=(28, 28, 1), name='flatten'),\n",
    "        Dense(10, name='fc3'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 4\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_dw\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant that interleaves depthwise convolutions.\n",
    "\n",
    "    Two stages, each a stride-2 3x3 Conv2D followed by a 3x3 DepthwiseConv2D, with batch\n",
    "    normalisation and ReLU after every convolution. The second depthwise layer uses\n",
    "    depth_multiplier=2. A global-average-pooling + 10-way softmax head follows.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first Conv2D; the second Conv2D uses 4*dim0.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        # stage 0: regular conv, then depthwise conv\n",
    "        Conv2D(dim0, (3, 3), padding='same', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0a'),\n",
    "        BatchNormalization(name='bn0'),\n",
    "        Activation('relu', name='relu0'),\n",
    "        DepthwiseConv2D((3, 3), padding='same', name='ftr0b'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu', name='relu00'),\n",
    "        # stage 1: regular conv, then depthwise conv with two outputs per input channel\n",
    "        Conv2D(dim0 * 4, (3, 3), padding='same', strides=(2, 2), name='ftr1a'),\n",
    "        BatchNormalization(name='bn1'),\n",
    "        Activation('relu', name='relu1'),\n",
    "        DepthwiseConv2D((3, 3), padding='same', depth_multiplier=2, name='ftr1b'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu', name='relu11'),\n",
    "        # classifier head\n",
    "        GlobalAveragePooling2D(name='GAP'),\n",
    "        Dense(10, name='fc1'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 4\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_rect\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant that uses non-square convolution kernels.\n",
    "\n",
    "    Three stride-2 'same'-padded convolutions with (15, 3), (7, 3) and (3, 3) kernels and\n",
    "    dim0, 2*dim0 and 4*dim0 filters, each followed by batch normalisation and ReLU, feed\n",
    "    a global-average-pooling + 10-way softmax head.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        # stage 0\n",
    "        Conv2D(dim0, (15, 3), padding='same', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0'),\n",
    "        BatchNormalization(name='bn0'),\n",
    "        Activation('relu', name='relu0'),\n",
    "        # stage 1\n",
    "        Conv2D(dim0 * 2, (7, 3), padding='same', strides=(2, 2), name='ftr1'),\n",
    "        BatchNormalization(name='bn1'),\n",
    "        Activation('relu', name='relu1'),\n",
    "        # stage 2 (the BN layer gets no explicit name here, so Keras assigns one)\n",
    "        Conv2D(dim0 * 4, (3, 3), padding='same', strides=(2, 2), name='ftr2'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu', name='relu2'),\n",
    "        # classifier head\n",
    "        GlobalAveragePooling2D(name='GAP'),\n",
    "        Dense(10, name='fc1'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 4\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_nnom\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant with a convolutional head instead of pooling.\n",
    "\n",
    "    Three stride-2 'same'-padded 3x3 convolutions (dim0, 2*dim0, 4*dim0 filters) with\n",
    "    batch normalisation and ReLU are followed by a 4x4 'valid' convolution with 8*dim0\n",
    "    filters, a reshape to a flat vector, and a 10-way softmax classifier.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    model.add(Conv2D(dim0, (3, 3), padding='same', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0'))\n",
    "    model.add(BatchNormalization(name='bn0'))\n",
    "    model.add(Activation('relu', name='relu0'))\n",
    "\n",
    "    model.add(Conv2D(dim0 * 2, (3, 3), padding='same', strides=(2, 2), name='ftr1'))\n",
    "    model.add(BatchNormalization(name='bn1'))\n",
    "    model.add(Activation('relu', name='relu1'))\n",
    "\n",
    "    model.add(Conv2D(dim0 * 4, (3, 3), padding='same', strides=(2, 2), name='ftr2'))\n",
    "    model.add(BatchNormalization())\n",
    "    model.add(Activation('relu', name='relu2'))\n",
    "\n",
    "    # The three stride-2 'same' convolutions shrink 28 -> 14 -> 7 -> 4, so this 4x4\n",
    "    # 'valid' convolution leaves a 1x1 map with dim0*8 channels.\n",
    "    head_channels = dim0 * 8\n",
    "    model.add(Conv2D(head_channels, (4, 4), padding='valid', name='ftr3'))\n",
    "    # Derive the flat size from dim0: the previous hard-coded 96 only matched dim0 == 12\n",
    "    # and made the reshape fail for every other width.\n",
    "    model.add(Reshape((head_channels,), name='reshape'))\n",
    "    model.add(Dense(10, name='fc1'))\n",
    "\n",
    "    model.add(Activation('softmax', name='sm'))\n",
    "    return model\n",
    "\n",
    "DIM0 = 12\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist_arduino\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the smallest MNIST CNN variant in this notebook ('valid' padding).\n",
    "\n",
    "    Three stride-2 3x3 convolutions with dim0, 3*dim0 and 6*dim0 filters, each followed\n",
    "    by batch normalisation and ReLU, feed a global-average-pooling + 10-way softmax head.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Sequential model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    stack = [\n",
    "        # stage 0\n",
    "        Conv2D(dim0, (3, 3), padding='valid', strides=(2, 2), input_shape=(28, 28, 1), name='ftr0'),\n",
    "        BatchNormalization(name='bn0'),\n",
    "        Activation('relu', name='relu0'),\n",
    "        # stage 1\n",
    "        Conv2D(dim0 * 3, (3, 3), padding='valid', strides=(2, 2), name='ftr1'),\n",
    "        BatchNormalization(name='bn1'),\n",
    "        Activation('relu', name='relu1'),\n",
    "        # stage 2 (BN and ReLU get no explicit name here, so Keras assigns one)\n",
    "        Conv2D(dim0 * 6, (3, 3), padding='valid', strides=(2, 2), name='ftr2'),\n",
    "        BatchNormalization(),\n",
    "        Activation('relu'),\n",
    "        # classifier head\n",
    "        GlobalAveragePooling2D(name='GAP'),\n",
    "        Dense(10, name='fc1'),\n",
    "        Activation('softmax', name='sm'),\n",
    "    ]\n",
    "    return Sequential(stack)\n",
    "\n",
    "DIM0 = 1\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist resnet\n",
    "def init_model(dim0):\n",
    "    \"\"\"Build the MNIST CNN variant with one residual (shortcut) connection.\n",
    "\n",
    "    Two stride-2 'same'-padded 3x3 convolutions (dim0, 2*dim0 filters) with batch\n",
    "    normalisation and ReLU form the stem. A further 3x3 convolution + batch\n",
    "    normalisation is added back onto the stem output, then a stride-2 'valid' 3x3\n",
    "    convolution (4*dim0 filters), a flatten and a 10-way softmax classifier follow.\n",
    "\n",
    "    Args:\n",
    "        dim0: filter count of the first convolution; later stages scale from it.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled functional keras Model for (28, 28, 1) inputs.\n",
    "    \"\"\"\n",
    "    inputs = Input(shape=(28, 28, 1))\n",
    "    x = Conv2D(dim0, (3, 3), padding='same', strides=(2, 2), name='ftr0')(inputs)\n",
    "    x = BatchNormalization(name='bn0')(x)\n",
    "    x = Activation('relu', name='relu0')(x)\n",
    "\n",
    "    x = Conv2D(dim0 * 2, (3, 3), padding='same', strides=(2, 2), name='ftr1')(x)\n",
    "    x = BatchNormalization(name='bn1')(x)\n",
    "    x = Activation('relu', name='relu1')(x)\n",
    "    res = x  # shortcut branch\n",
    "\n",
    "    # Residual branch: same channel count and stride 1, so its shape matches `res`.\n",
    "    x = Conv2D(dim0 * 2, (3, 3), padding='same', name='ftr2')(x)\n",
    "    x = BatchNormalization(name='bn2')(x)\n",
    "\n",
    "    # Merge with an explicit Add layer rather than `res + x`: the operator form gets\n",
    "    # wrapped in an auto-named op layer, whereas Add is a regular, explicitly named\n",
    "    # Keras layer in the model graph and in the saved file.\n",
    "    x = tf.keras.layers.Add(name='res_add')([res, x])\n",
    "\n",
    "    x = Conv2D(dim0 * 4, (3, 3), padding='valid', strides=(2, 2), name='ftr3')(x)\n",
    "    x = Flatten(name='reshape')(x)\n",
    "    x = Dense(10, name='fc1')(x)\n",
    "\n",
    "    x = Activation('softmax', name='sm')(x)\n",
    "    model = Model(inputs=inputs, outputs=x)\n",
    "\n",
    "    return model\n",
    "\n",
    "DIM0 = 12\n",
    "model = init_model(DIM0)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 20\n",
    "\n",
    "# Categorical cross-entropy pairs with the one-hot labels built during preprocessing.\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['categorical_accuracy'],\n",
    ")\n",
    "\n",
    "# NOTE(review): the test split doubles as validation data here, so the reported\n",
    "# validation metrics are not an independent held-out estimate.\n",
    "H = model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=64,\n",
    "    epochs=EPOCHS,\n",
    "    verbose=1,\n",
    "    validation_data=(x_test, y_test),\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write whichever network is currently bound to `model` to disk.\n",
    "# NOTE(review): the file name is fixed to the resnet variant; rename it if a different\n",
    "# init_model cell was the one executed.\n",
    "h5_file = 'mnist_resnet.h5'\n",
    "model.save(filepath=h5_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print test image 1 as a comma-separated grid of 0-255 integers, one image row per line.\n",
    "data = x_test[1]\n",
    "\n",
    "# The pixels were divided by 255 during preprocessing. Round (rather than truncate with\n",
    "# int()) when undoing that scaling, so floating-point error cannot print a value that\n",
    "# is one too low. Iterating over the array also removes the hard-coded 28x28 bounds.\n",
    "pixels = np.rint(data[..., 0] * 255).astype(int)\n",
    "for row in pixels:\n",
    "    print(''.join('%3d,' % value for value in row))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 190ms/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([9.6214655e-14, 3.2803200e-18, 1.0000000e+00, 3.8382069e-22,\n",
       "       5.4997339e-24, 5.6444559e-32, 5.4293467e-14, 3.6996416e-21,\n",
       "       3.2022371e-21, 1.7273456e-21], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classify test image 1. Slicing with 1:2 keeps the batch axis, so predict() returns a\n",
    "# single row of softmax outputs, which is unpacked into `result`.\n",
    "(result,) = model.predict(x_test[1:2])\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf29",
   "language": "python",
   "name": "py374_tf21"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
